"""
This code based on codes from https://github.com/tristandeleu/ntm-one-shot \
                              and https://github.com/kjunelee/MetaOptNet
"""
import numpy as np
import random
import pickle as pkl
from functions import *
import torch
import time
import torch.nn.functional as F

class miniImageNetGenerator(object):
    """Episode generator that splits sampled mini-ImageNet images across users.

    Each episode draws ``nb_classes`` random classes and
    ``num_user * (n_spt + n_qry)`` random images per class into a pool that is
    laid out class by class.  The pool is cut into ``total_shards`` contiguous
    shards of ``shard_size`` images.  Every user receives ``nb_shards`` distinct
    random shards; of its first two shards, the first half of each becomes
    support data and the second half becomes query data.  (So ``n_spt`` and
    ``n_qry`` only determine the pool size, not the per-user split sizes.)

    Args:
        data_file: path to a pickle holding a dict with ``'data'`` (an array
            indexed by sample; assumed N x H x W x C given the permute to
            channels-first) and ``'labels'`` (one label per sample).
        nb_classes: number of classes drawn per episode.
        num_user: number of users each episode is split across.
        n_spt: together with ``n_qry`` sets images drawn per class.
        n_qry: together with ``n_spt`` sets images drawn per class.
        max_iter: number of episodes to yield before stopping; ``None`` means
            unlimited.
        shard_size: pool images per shard (previously hard-coded to 30).
        total_shards: number of shards the pool is cut into (previously
            hard-coded to 20).  Must be at least ``num_user * nb_shards``,
            otherwise ``random.sample`` raises ``ValueError``.
    """

    def __init__(self, data_file, nb_classes, num_user, n_spt,
                 n_qry, max_iter=None, shard_size=30, total_shards=20):
        super(miniImageNetGenerator, self).__init__()
        self.data_file = data_file
        self.nb_classes = nb_classes
        self.max_iter = max_iter
        self.num_iter = 0
        self.data_dict = self._load_data(self.data_file)
        self.num_user = num_user
        self.n_spt = n_spt
        self.n_qry = n_qry
        self.nb_shards = 2  # shards handed to each user per episode
        self.shard_size = shard_size
        self.total_shards = total_shards

    def _load_data(self, data_file):
        """Load ``data_file`` and group its samples by label.

        Returns a dict mapping each label to a tensor of that label's samples,
        permuted with ``[0, 3, 1, 2]`` (channels-last -> channels-first,
        assuming the stored array is N x H x W x C).
        """
        dataset = self.load_data(data_file)
        data = dataset['data']
        labels = dataset['labels']
        label2ind = self.buildLabelIndex(labels)
        return {key: torch.tensor(data[val]).permute([0, 3, 1, 2])
                for key, val in label2ind.items()}

    def load_data(self, data_file):
        """Unpickle and return the contents of ``data_file``.

        If the default decoding fails with ``UnicodeDecodeError`` the file is
        re-read with ``encoding='latin1'`` (presumably for pickles written by
        Python 2).
        """
        # NOTE(review): pickle can execute arbitrary code; only load trusted files.
        try:
            with open(data_file, 'rb') as fo:
                return pkl.load(fo)
        except UnicodeDecodeError:
            with open(data_file, 'rb') as fo:
                return pkl.load(fo, encoding='latin1')

    def buildLabelIndex(self, labels):
        """Map each distinct label to the list of positions where it occurs."""
        label2inds = {}
        for idx, label in enumerate(labels):
            label2inds.setdefault(label, []).append(idx)
        return label2inds

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return ``(episode_index, x_spt, y_spt, x_qry, y_qry)``.

        Raises:
            StopIteration: once ``max_iter`` episodes have been produced.
        """
        if (self.max_iter is None) or (self.num_iter < self.max_iter):
            self.num_iter += 1
            x_spt, y_spt, x_qry, y_qry = self.sample(self.nb_classes)
            return (self.num_iter - 1), x_spt, y_spt, x_qry, y_qry
        raise StopIteration()

    def _draw_pool(self, nb_classes):
        """Draw the episode pool: images and episode-local labels, grouped by class.

        Returns ``(pool_x, pool_y)`` where rows
        ``[k * per_class, (k + 1) * per_class)`` hold the k-th drawn class and
        ``pool_y`` holds ``k`` for those rows.
        """
        per_class = self.num_user * (self.n_spt + self.n_qry)
        # random.sample rejects dict views/sets on Python >= 3.11; list() keeps
        # the same (insertion) order the old implicit tuple() conversion used.
        classes = random.sample(list(self.data_dict), nb_classes)
        # Take the image shape from the data instead of hard-coding 3x84x84.
        img_shape = (tuple(self.data_dict[classes[0]].shape[1:])
                     if classes else (3, 84, 84))
        pool_x = torch.zeros(nb_classes * per_class, *img_shape)
        pool_y = torch.zeros(nb_classes * per_class, dtype=torch.long)
        for cls_idx, cls in enumerate(classes):
            imgs = self.data_dict[cls]
            chosen = random.sample(range(len(imgs)), per_class)
            start = cls_idx * per_class
            pool_x[start:start + per_class] = imgs[chosen]
            pool_y[start:start + per_class] = cls_idx
        return pool_x, pool_y

    def sample(self, nb_classes):
        """Draw one episode and split it across ``self.num_user`` users.

        Returns:
            ``(x_spt, y_spt, x_qry, y_qry)``: four dicts keyed by user index
            ``0 .. num_user - 1``.  ``x_*`` values are float32 image tensors;
            ``y_*`` values are int64 tensors holding the episode-local class
            index (position of the class in this episode's draw).
        """
        pool_x, pool_y = self._draw_pool(nb_classes)
        size = self.shard_size

        x_spt, y_spt, x_qry, y_qry = {}, {}, {}, {}
        # sorted() gives a sequence (required by random.sample on 3.11+) in
        # the same ascending order a set of small ints used to iterate in.
        available = set(range(self.total_shards))
        for user in range(self.num_user):
            picks = random.sample(sorted(available), self.nb_shards)
            available.difference_update(picks)  # shards are never reused
            first, second = picks[0], picks[1]

            x_1 = pool_x[first * size:(first + 1) * size]
            y_1 = pool_y[first * size:(first + 1) * size]
            x_2 = pool_x[second * size:(second + 1) * size]
            y_2 = pool_y[second * size:(second + 1) * size]

            # First half of each shard -> support, remainder -> query.
            half_1, half_2 = len(x_1) // 2, len(x_2) // 2
            x_spt[user] = torch.cat([x_1[:half_1], x_2[:half_2]])
            y_spt[user] = torch.cat([y_1[:half_1], y_2[:half_2]])
            x_qry[user] = torch.cat([x_1[half_1:], x_2[half_2:]])
            y_qry[user] = torch.cat([y_1[half_1:], y_2[half_2:]])

        return x_spt, y_spt, x_qry, y_qry


